function average_TFRs(baselineStrategy, loadIndividual, foi, badSs)
% average_TFRs.m
%
% Computes individual-subject and then grand-average time-frequency
% representations (TFRs) in the encoding (ENC), interfering (INT), and
% retrieval (RET) windows for each of the 12 experimental conditions, plus
% interference-minus-no-interference difference TFRs in the INT and RET
% windows. Results are written to 'indSs_TFR.mat' and 'avg_TFR.mat' in the
% machine-specific save directory. The function returns nothing.
%
% Parameters
%   baselineStrategy  'within' or 'across' (case-insensitive). With
%                     'across', the first samples of every condition
%                     average are overwritten with the subject's baseline
%                     averaged over the auditory and visual sessions
%                     before baseline normalization; with 'within' each
%                     condition keeps its own baseline samples.
%                     Defaults to 'across' (with a warning) when omitted.
%   loadIndividual    (optional, default = false) If true, individual
%                     subject averages are loaded from 'indSs_TFR.mat'
%                     instead of being recomputed. This errors if the
%                     function has not previously been run with
%                     loadIndividual set to false.
%   foi               (optional, default = -1) Scalar max frequency to
%                     average up to; -1 keeps all frequencies in the data.
%                     Mainly useful when Fourier analyses were carried out
%                     with different frequencies for different subjects.
%                     NB: This should be a temporary state.
%   badSs             (optional, default = []) [1 x N] array of subject
%                     numbers (NOT ID's) to exclude from the grand average.
%
% Requires FieldTrip (ft_freqbaseline, ft_freqgrandaverage) and the file
% 'dummy_indSs_TFR.mat' on the MATLAB path.

freqStep = 1; % frequency resolution or step size (Hz/step)

%% Argument defaults and validation
if nargin < 4
    badSs = []; % Assume no bad subjects by default
end
if nargin < 3 || isempty(foi)
    foi = -1; % keep all frequencies by default
end
% Validate foi whether it was defaulted or supplied by the caller
if ~isnumeric(foi) || numel(foi) ~= 1 || ~(foi == -1 || foi > 0)
    error('Invalid specification of frequencies of interest (foi)')
end
if foi > 0
    foi = foi / freqStep; % express the upper bound in frequency steps
end
if nargin < 2 || isempty(loadIndividual)
    loadIndividual = false; % compute ind Ss averages by default
end
if nargin < 1 || isempty(baselineStrategy)
    warning("Baseline strategy not set. Using default of 'across'")
    baselineStrategy = 'across';
end
% Normalize case once so the later comparisons agree with this validation
baselineStrategy = lower(baselineStrategy);
if ~ismember(baselineStrategy, {'within', 'across'})
    error('invalid baselineStrategy. See function help.')
end

%% Machine-specific paths
compname = getenv('computername');

if strcmp(compname,'DESKTOP-BHR0AU7')
    rawdir = 'E:\ANL\Experiments\DATA\VAST\RAW\';
    datadir = 'E:\ANL\Experiments\DATA\VAST\PREPROCESSED\TFR\';
    savedir = 'E:\ANL\Experiments\RESULTS\VAST\EEG\';
    addpath('C:\Users\jtjus\Documents\MATLAB\fieldtrip-20180512\fieldtrip-20180512');
    addpath(genpath('E:\MATLAB\CSDtoolbox'));
elseif strcmp(compname,'MSI')
    rawdir = 'E:\ANL\RAW_DATA\VAST\';
    datadir = 'E:\ANL\PREPROCESSED_DATA\VAST\TFR\';
    savedir = 'E:\ANL\RESULTS\VAST\EEG\';
    addpath('D:\MATLAB\fieldtrip-20180809\fieldtrip-20180809');
    addpath(genpath('D:\MATLAB\CSDtoolbox'));
else
    error('need to set up paths for this machine.')
end

%% Subjects, conditions, windows
% Entries in datadir are named <subjectID>_a / <subjectID>_v; drop the
% two-character session suffix to recover the unique subject IDs. The '.'
% and '..' entries are removed by name because dir() does not guarantee
% that they are listed first.
allPreprocFiles = dir(datadir);
allPreprocFiles = {allPreprocFiles.name};
allPreprocFiles = allPreprocFiles(~ismember(allPreprocFiles, {'.', '..'}));
sIDs = cellfun(@(x) x(1:end-2), allPreprocFiles, 'UniformOutput', false);
sIDs = unique(sIDs);
Nss = length(sIDs);

standardConds = {'as_as','as_at','as_none','at_as','at_at','at_none',...
    'vs_as','vs_at','vs_none','vt_as','vt_at','vt_none'};
standardConds_wInt = standardConds(~contains(standardConds, 'none'));
standardWins = {'enc', 'int', 'ret'};
standardWins_noenc = standardWins(~contains(standardWins, 'enc'));
Nwins = length(standardWins);
modalities = {'a', 'v'}; % session suffixes: auditory, visual

% Template FieldTrip struct that individual power spectra are plugged into
fakeTFR = load('dummy_indSs_TFR.mat');
fakeTFR = fakeTFR.dummy_indSs_TFR;
blcfg = [];
blcfg.baseline = [0 1];
blcfg.baselinetype = 'relative';
blcfg.parameter = 'powspctrm';

indivFile = fullfile(savedir, 'indSs_TFR.mat');

if ~loadIndividual
    %% Preallocate individual subjects data
    % Each subject holds a 2 x Nconds cell (row 1 = condition names,
    % row 2 = a 2 x Nwins cell: row 1 = window names, row 2 = power).
    condTemplate = [standardConds; cell(1, length(standardConds))];
    for c = 1:length(standardConds)
        condTemplate{2,c} = [standardWins; cell(1, Nwins)];
    end
    % Only going to compute diffs in INT and RET windows
    diffTemplate = [standardConds_wInt; cell(1, length(standardConds_wInt))];
    for c = 1:length(standardConds_wInt)
        diffTemplate{2,c} =...
            [standardWins_noenc; cell(1, length(standardWins_noenc))];
    end
    indSs = repmat({condTemplate}, 1, Nss);
    indSs_diff = repmat({diffTemplate}, 1, Nss);

    for s = 1:Nss
        sID = sIDs{s};

        %% Load average baselines, average across aud and vis sessions
        BLs = cell(1, numel(modalities));
        for mi = 1:numel(modalities)
            sess = [sID '_' modalities{mi}];
            loaded = load(fullfile(datadir, sess,...
                [sess '_avg_TFR_BL.mat']), 'BL');
            BLs{mi} = loaded.BL;
            % apply the upper frequency bound, if one was specified
            if foi > 0 && size(BLs{mi}, 2) > foi
                BLs{mi} = BLs{mi}(:, 1:foi, :);
            end
        end
        BL = squeeze(mean(cat(4, BLs{:}), 4));
        BLlen = size(BL, 3); % length in TFR samples

        % Behavioral files for both sessions live in the same raw folder
        rawFiles = dir(fullfile(rawdir, sID));
        rawFiles = {rawFiles.name};

        %% AUDITORY then VISUAL data (identical processing)
        for mi = 1:numel(modalities)
            modality = modalities{mi};
            sess = [sID '_' modality];

            % load the by-window TFRs into a struct rather than straight
            % into the workspace, so file contents cannot clobber locals
            tfrs = load(fullfile(datadir, sess, [sess '_TFR_bywin.mat']),...
                'tfr_enc', 'tfr_int', 'tfr_ret');
            % load the bad epochs reference
            loaded = load(fullfile(datadir, sess,...
                [sess '_STEP9_vars.mat']), 'badEpochs');
            badEpochs = loaded.badEpochs;
            % load the behavioral data
            isBehav = contains(rawFiles, ['_' modality]) &...
                contains(rawFiles, '.mat');
            if nnz(isBehav) ~= 1
                error(['Expected exactly one ''_%s'' behavioral .mat file ',...
                    'for subject %s, found %d.'], modality, sID, nnz(isBehav))
            end
            loaded = load(fullfile(rawdir, sID, rawFiles{isBehav}), 'cfg');
            behavCfg = loaded.cfg;

            % assign conditions to trials (every trial in a block shares
            % that block's condition label)
            trlConds = cell(1, behavCfg.nBlocks * behavCfg.nTrials);
            for b = 1:behavCfg.nBlocks
                blockCond = strcat(behavCfg.m1{b}, behavCfg.d{b}, '_',...
                    behavCfg.dist{b});
                inds = ((b-1)*behavCfg.nTrials) + 1 : b * behavCfg.nTrials;
                trlConds(inds) = {blockCond};
            end

            % conditions belonging to this session's modality
            modConds = standardConds(startsWith(standardConds,...
                {[modality 's_'], [modality 't_']}));

            for w = 1:Nwins
                currwin = standardWins{w};
                currbad = badEpochs.(currwin);
                goodtrl_conds = trlConds(~currbad);
                pow = tfrs.(['tfr_' currwin]).powspctrm;
                for c = 1:numel(modConds)
                    currcond = modConds{c};
                    currtrls = strcmp(goodtrl_conds, currcond);
                    % get averaged power spectra -- not saving whole structures
                    if foi == -1 % keep all freqs
                        temp = squeeze(nanmean(pow(currtrls,:,:,:), 1));
                    else % specifying an upper freq bound
                        temp = squeeze(nanmean(pow(currtrls,:,1:foi,:), 1));
                    end
                    % replace the grand average baseline if this is requested
                    if strcmp(baselineStrategy, 'across')
                        temp(:,:,1:BLlen) = BL;
                    end
                    % get rid of resilient last sample NaNs (if present)
                    if ismember(size(temp, 3),...
                            find(isnan(squeeze(sum(sum(temp, 1), 2)))))
                        temp = temp(:,:,1:end-1);
                    end
                    if sum(isnan(temp(:))) > 0
                        warning("NaNs still present in subject %s, cond %s, win %s \n",...
                            sID, currcond, currwin)
                    end
                    % normalize the individual subject data by the baseline
                    ssTFR = fakeTFR;
                    ssTFR.powspctrm = temp;
                    ssTFR.time = ssTFR.time(1:size(temp, 3));
                    ssTFR = ft_freqbaseline(blcfg, ssTFR);
                    % plug into the individual subjects average cell array
                    indSs{s}{2, strcmp(standardConds, currcond)}{2,w} =...
                        ssTFR.powspctrm;
                end % condition loop
            end % trial window loop
        end % modality loop

        %% Compute Int-NoInt relative TFRs
        % Simple difference (not a ratio) between each interference
        % condition and the matching '<task>_none' condition, truncated to
        % the shorter of the two windows.
        for ci = 1:length(standardConds_wInt)
            cInt = standardConds_wInt{ci};
            wInt_ind = strcmp(standardConds, cInt);
            noInt_ind = strcmp(standardConds, [cInt(1:2) '_none']);
            for wi = 1:length(standardWins_noenc)
                win_ind = strcmp(standardWins, standardWins_noenc{wi});
                currdata_wInt = indSs{s}{2, wInt_ind}{2, win_ind};
                currdata_noInt = indSs{s}{2, noInt_ind}{2, win_ind};
                winLen = min([size(currdata_wInt,3), size(currdata_noInt,3)]);
                indSs_diff{s}{2, ci}{2, wi} =...
                    currdata_wInt(:,:,1:winLen) - currdata_noInt(:,:,1:winLen);
            end
        end

    end % subject loop

    %% Save
    save(indivFile, 'indSs', 'indSs_diff', '-v7.3')

else % Load the individual subjects' averages from a previous run
    loaded = load(indivFile, 'indSs', 'indSs_diff');
    indSs = loaded.indSs;
    indSs_diff = loaded.indSs_diff;
end


%% Average

% Template FieldTrip struct for the grand average (no taper counts)
dummy_indSs_TFR = fakeTFR;
if isfield(dummy_indSs_TFR, 'cumtapcnt')
    dummy_indSs_TFR = rmfield(dummy_indSs_TFR, 'cumtapcnt');
end
% adjust the frequency content according to foi
if foi == -1 % use all freqs in the indSs data
    dummy_indSs_TFR.freq = 1 : freqStep : size(indSs{1}{2,1}{2,1}, 2);
else % max frequency is specified
    dummy_indSs_TFR.freq = 1 : freqStep : foi;
end

% Preallocate average data
avg_TFR = [standardConds; cell(1, length(standardConds))];
avg_TFR_diff = [standardConds_wInt; cell(1, length(standardConds_wInt))];
for c = 1:length(standardConds)
    avg_TFR{2,c} = [standardWins; cell(1, Nwins)];
end
for c = 1:length(standardConds_wInt)
    avg_TFR_diff{2,c} = [standardWins_noenc;...
        cell(1, length(standardWins_noenc))];
end

% Subjects that enter the grand average
goodSs = 1:Nss;
goodSs(ismember(goodSs, badSs)) = [];
if isempty(goodSs)
    error('No subjects left to average after excluding badSs.')
end
gacfg = [];
gacfg.parameter = 'powspctrm';

%%% --- First, normal condition-specific TFR averages ---
for c = 1:length(standardConds)
    for w = 1:Nwins
        avg_TFR{2,c}{2,w} = grandAverageTFR(gacfg, indSs, goodSs, c, w,...
            dummy_indSs_TFR);
    end % window loop
end % condition loop

%%% --- Then, difference TFR averages ---
for c = 1:length(standardConds_wInt)
    for w = 1:length(standardWins_noenc)
        avg_TFR_diff{2,c}{2,w} = grandAverageTFR(gacfg, indSs_diff,...
            goodSs, c, w, dummy_indSs_TFR);
    end % window loop
end % condition loop

save(fullfile(savedir, 'avg_TFR.mat'), 'avg_TFR', 'avg_TFR_diff', '-v7.3')


function avg = grandAverageTFR(gacfg, data, goodSs, c, w, template)
% Grand-average one condition/window cell of "data" across the subjects
% listed in goodSs.
%
%   gacfg     cfg struct passed to ft_freqgrandaverage
%   data      per-subject nested cell array (indSs or indSs_diff layout)
%   goodSs    subject numbers to include
%   c, w      condition and window column indices into the nested cells
%   template  FieldTrip struct whose powspctrm/time are replaced per subject
%
% Every subject is truncated to the frequencies in template.freq and to the
% shortest time axis among the included subjects before averaging.

nGood = numel(goodSs);

% Find the minimum length for this condition/time window
lens = zeros(1, nGood);
for k = 1:nGood
    lens(k) = size(data{goodSs(k)}{2,c}{2,w}, 3);
end
minlen = min(lens);
nFreq = length(template.freq);

% Populate FieldTrip structs with the actual power spectra
currIndiv = cell(1, nGood);
for k = 1:nGood
    currIndiv{k} = template;
    currIndiv{k}.powspctrm = data{goodSs(k)}{2,c}{2,w}(:, 1:nFreq, 1:minlen);
    currIndiv{k}.time = template.time(1:minlen);
end

avg = ft_freqgrandaverage(gacfg, currIndiv{:});

% Remove the massive "previous" field of average cfg, if present
if isfield(avg.cfg, 'previous')
    avg.cfg = rmfield(avg.cfg, 'previous');
end





        
        
        
        
        
        
